//===- TestTilingInterfaceTransformOps.cpp - Test `TilingInterface` ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines transform dialect operations used for testing
// TilingInterface
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "test-tiling-interface"

#define GET_OP_CLASSES
#include "TestTilingInterfaceTransformOps.h.inc"

using namespace mlir;
using namespace mlir::transform;

//===----------------------------------------------------------------------===//
// TestFuseAndYieldOp
//===----------------------------------------------------------------------===//

/// Computes the set of operations that a tile-and-fuse rooted at `op` may
/// touch: `op` itself plus every operation reachable by repeatedly following
/// operand producers that implement `TilingInterface`. Block arguments and
/// producers that do not implement the interface stop the walk.
static llvm::SmallDenseSet<Operation *> collectTiledAndFusedOps(Operation *op) {
  llvm::SmallDenseSet<Operation *> visited;
  SmallVector<Operation *> pending;
  visited.insert(op);
  pending.push_back(op);

  while (!pending.empty()) {
    Operation *consumer = pending.pop_back_val();
    for (Value operandValue : consumer->getOperands()) {
      auto producer = operandValue.getDefiningOp<TilingInterface>();
      if (!producer)
        continue;
      // Only enqueue producers that have not been visited yet.
      if (visited.insert(producer.getOperation()).second)
        pending.push_back(producer.getOperation());
    }
  }
  return visited;
}

/// Apply a tile and fuse transformation to all payload ops and store both the
/// tiled operation as well as the created tile loops.
///
/// For each payload op (which must implement `TilingInterface`), the op is
/// tiled with `tilingOptions` and its producers are fused into the generated
/// loops via `scf::tileConsumerAndFuseProducersUsingSCF`. Uses of the original
/// values that are properly dominated by the corresponding loop result and
/// live in the same parent op are rewired to that result; originals left
/// without any use are erased.
///
/// Result #0 of `transformOp` receives the first entry of `tiledAndFusedOps`
/// for every payload op; result #(i + 1) receives the i-th generated loop of
/// every payload op, for i in [0, numLoops).
///
/// Returns failure (with a diagnostic for non-TilingInterface payload ops) if
/// any payload op cannot be tiled. `numLoops` must equal the number of loops
/// produced by the tiling; this is only checked by an assertion.
template <typename Range>
static LogicalResult
applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
                      Range &&payloadOps, unsigned numLoops,
                      scf::SCFTilingOptions tilingOptions,
                      TransformResults &transformResults) {
  SmallVector<Operation *> tiledOps;
  SmallVector<SmallVector<Operation *>> loopOps(numLoops);

  for (Operation *target : payloadOps) {
    auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
    if (!tilingInterfaceOp)
      return transformOp->emitError("only TilingInterface ops are supported");
    // Dominance is used both to pick which fused producers must be yielded
    // and to restrict which uses are rewired after tiling.
    DominanceInfo dominanceInfo(tilingInterfaceOp);

    // Any op of the tiled/fused closure that has a user properly dominated by
    // the root payload op (i.e. a later user) needs its tiled value yielded
    // from the loop nest so that user can be rewired to a loop result.
    llvm::SmallDenseSet<Operation *> tiledAndFusedOps =
        collectTiledAndFusedOps(tilingInterfaceOp);
    llvm::DenseSet<Operation *> yieldReplacementsFor;
    for (auto op : tiledAndFusedOps) {
      if (llvm::any_of(op->getUsers(), [&](Operation *user) {
            return dominanceInfo.properlyDominates(tilingInterfaceOp, user);
          })) {
        yieldReplacementsFor.insert(op);
      }
    }

    scf::SCFTileAndFuseOptions tileAndFuseOptions;
    tileAndFuseOptions.setTilingOptions(tilingOptions);

    // Fuse every candidate producer; request a yielded replacement only for
    // producers selected above.
    scf::SCFTileAndFuseOptions::ControlFnTy controlFn =
        [&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
            bool isDestinationOperand)
        -> std::optional<scf::SCFTileAndFuseOptions::ControlFnResult> {
      Operation *owner = originalProducer.getOwner();
      bool yieldProducerReplacement = yieldReplacementsFor.contains(owner);
      return scf::SCFTileAndFuseOptions::ControlFnResult{
          yieldProducerReplacement};
    };
    tileAndFuseOptions.setFusionControlFn(controlFn);

    rewriter.setInsertionPoint(target);
    FailureOr<scf::SCFTileAndFuseResult> tiledResults =
        scf::tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
                                                  tileAndFuseOptions);
    if (failed(tiledResults))
      return failure();

    // Perform the replacement of tiled and fused values.
    SmallVector<Operation *> opsToReplace{target};
    llvm::append_range(opsToReplace, tiledResults->fusedProducers);
    for (Operation *toReplace : opsToReplace) {
      for (OpResult res : toReplace->getResults())
        if (auto replacement = tiledResults->replacements.lookup(res)) {
          Operation *replacementOp = replacement.getDefiningOp();
          // Only rewire uses that the replacement properly dominates and that
          // share its parent op; all other uses keep the original value.
          rewriter.replaceUsesWithIf(res, replacement, [&](OpOperand &use) {
            Operation *user = use.getOwner();
            return dominanceInfo.properlyDominates(replacementOp, user) &&
                   user->getParentOp() == replacementOp->getParentOp();
          });
        }

      // Originals whose uses were all rewired are now dead.
      if (toReplace->use_empty()) {
        rewriter.eraseOp(toReplace);
      }
    }

    // Report back the relevant handles to the transform op.
    tiledOps.push_back(tiledResults->tiledAndFusedOps.front());
    assert(tiledResults->loops.size() == numLoops &&
           "Mismatched number of loops, tile and fuse transform should have "
           "failed");
    for (unsigned int i = 0; i < numLoops; ++i)
      loopOps[i].push_back(tiledResults->loops[i]);
  }

  transformResults.set(transformOp->getOpResult(0), tiledOps);
  for (unsigned int i = 0; i < numLoops; ++i)
    transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);

  return success();
}

/// Tiles each target op (and fuses its producers) using the static tile sizes
/// and interchange of this op, with either scf.for (default) or scf.forall
/// loops when `use_forall` is set.
DiagnosedSilenceableFailure
transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
                                     TransformResults &transformResults,
                                     TransformState &state) {
  SmallVector<int64_t> tileSizes =
      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
  SmallVector<int64_t> tileInterchange =
      extractFromIntegerArrayAttr<int64_t>(getTileInterchange());

  SmallVector<OpFoldResult> tileSizesOfr =
      getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);

  scf::SCFTilingOptions tilingOptions;
  tilingOptions.setTileSizes(tileSizesOfr).setInterchange(tileInterchange);
  if (getUseForall()) {
    tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
  }

  // scf.for tiling creates one loop per non-zero tile size, while scf.forall
  // tiling creates a single multi-dimensional loop covering every tiled
  // dimension. Passing the scf.for count in the forall case would make
  // `applyTileAndFuseToAll` expect (and index) loops that do not exist.
  unsigned numTiledDims = tileSizes.size() - llvm::count(tileSizes, 0);
  unsigned numLoops =
      getUseForall() ? (numTiledDims == 0 ? 0u : 1u) : numTiledDims;

  LogicalResult result = applyTileAndFuseToAll(
      rewriter, getOperation(), state.getPayloadOps(getTarget()), numLoops,
      tilingOptions, transformResults);
  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
                        : DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// TestFuseConsumerOp
//===----------------------------------------------------------------------===//

/// Apply fusing of consumer transformation to all payload ops and store both
/// the original consumer operation as well as the fused consumer operation.
///
/// Runs `scf::tileAndFuseConsumerOfSlices` on `slices`/`loops`
/// `numConsumerToFuse` times. Result #0 of `transformOp` receives the owners
/// of the original consumer operands and result #1 the owners of the
/// tiled-and-fused consumer operands, accumulated over all fusion steps.
/// Emits an error and fails if `slices` is empty or if any fusion step fails.
static LogicalResult applyFuseConsumer(
    RewriterBase &rewriter, Operation *transformOp,
    ArrayRef<Operation *> slices, MutableArrayRef<LoopLikeOpInterface> loops,
    uint32_t numConsumerToFuse, TransformResults &transformResults) {
  // The first slice anchors the insertion point and the diagnostics below, so
  // an empty list must be rejected before `slices.front()` is accessed.
  if (slices.empty())
    return transformOp->emitOpError("expected at least one slice");

  SmallVector<Operation *> originalConsumerOps;
  SmallVector<Operation *> fusedConsumerOps;

  rewriter.setInsertionPoint(slices.front());

  while (numConsumerToFuse--) {
    FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
        scf::tileAndFuseConsumerOfSlices(rewriter, slices, loops);

    if (failed(fuseConsumerResults))
      return slices.front()->emitOpError("failed to fuse consumer of slice");

    // Report back the relevant handles to the transform op.
    for (OpOperand *origConsumerOperand :
         fuseConsumerResults->origConsumerOperands)
      originalConsumerOps.push_back(origConsumerOperand->getOwner());
    for (OpOperand *tiledAndFusedConsumerOperand :
         fuseConsumerResults->tiledAndFusedConsumerOperands)
      fusedConsumerOps.push_back(tiledAndFusedConsumerOperand->getOwner());
  }

  transformResults.set(transformOp->getOpResult(0), originalConsumerOps);
  transformResults.set(transformOp->getOpResult(1), fusedConsumerOps);
  return success();
}

/// Collects the first payload op of every target handle as a slice and the
/// first payload op of every loop handle as a loop, then fuses consumers of
/// the slices into those loops.
DiagnosedSilenceableFailure
transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
                                     TransformResults &transformResults,
                                     TransformState &state) {
  SmallVector<Operation *> slices;
  for (auto op : getTargets()) {
    auto payloadOps = state.getPayloadOps(op);
    // Dereferencing `begin()` of an empty range is undefined behavior.
    if (llvm::empty(payloadOps)) {
      emitOpError("expected every target handle to map to a payload op");
      return DiagnosedSilenceableFailure::definiteFailure();
    }
    slices.push_back(*payloadOps.begin());
  }

  // Loop handles are visited in the reverse of their operand order.
  SmallVector<LoopLikeOpInterface> loops;
  for (auto op : llvm::reverse(getLoops())) {
    auto payloadOps = state.getPayloadOps(op);
    if (llvm::empty(payloadOps)) {
      emitOpError("expected every loop handle to map to a payload op");
      return DiagnosedSilenceableFailure::definiteFailure();
    }
    auto loopLikeOp = dyn_cast<LoopLikeOpInterface>(*payloadOps.begin());
    if (!loopLikeOp) {
      emitOpError("expected loop handles to map to LoopLikeOpInterface ops");
      return DiagnosedSilenceableFailure::definiteFailure();
    }
    loops.push_back(loopLikeOp);
  }
  LogicalResult result =
      applyFuseConsumer(rewriter, getOperation(), slices, loops,
                        getNumConsumerToFuse(), transformResults);
  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
                        : DiagnosedSilenceableFailure::success();
}

void transform::TestFuseConsumerOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // Both the slice handles and the loop handles are consumed; every result is
  // a freshly produced handle, and the payload IR is modified.
  consumesHandle(getTargetsMutable(), effects);
  consumesHandle(getLoopsMutable(), effects);
  Operation *self = getOperation();
  producesHandle(self->getOpResults(), effects);
  modifiesPayload(effects);
}

//===----------------------------------------------------------------------===//
// TestTileUsingForallOp
//===----------------------------------------------------------------------===//

/// Apply a tiling transformation to all payload ops and store both the
/// tiled operation as well as the created tile loops.
///
/// Every payload op must implement `TilingInterface`. It is tiled into
/// `scf.forall` loops with the given tile sizes, interchange and optional
/// `mapping` attributes, and is then replaced by the tiled results.
///
/// Result #0 of `transformOp` receives the first tiled op of each payload op.
/// Each remaining result #(i + 1) receives the i-th generated loop of every
/// payload op. The number of generated loops must therefore equal
/// `transformOp->getNumResults() - 1`; otherwise an error is emitted.
template <typename Range>
static LogicalResult
applyTileToAll(RewriterBase &rewriter, Operation *transformOp,
               Range &&payloadOps, ArrayRef<OpFoldResult> tileSizes,
               ArrayRef<int64_t> interchange, std::optional<ArrayAttr> mapping,
               TransformResults &transformResults) {
  // Result #0 is the tiled-op handle; all remaining results are loop handles.
  unsigned numLoops = transformOp->getNumResults() - 1;
  SmallVector<Operation *> tiledOps;
  // loopOps[i] gathers the i-th loop of every payload op, so the number of
  // results written is independent of how many payload ops are tiled.
  SmallVector<SmallVector<Operation *>> loopOps(numLoops);

  // The options do not depend on the payload op, so build them once.
  scf::SCFTilingOptions tilingOptions;
  tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
  if (mapping)
    tilingOptions.setMapping(mapping.value().getValue());
  tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);

  for (Operation *target : payloadOps) {
    auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
    if (!tilingInterfaceOp)
      return transformOp->emitError("only TilingInterface ops are supported");

    rewriter.setInsertionPoint(target);
    FailureOr<scf::SCFTilingResult> tiledResults =
        scf::tileUsingSCF(rewriter, tilingInterfaceOp, tilingOptions);
    if (failed(tiledResults))
      return failure();
    if (tiledResults->loops.size() != numLoops)
      return transformOp->emitError("expected ")
             << numLoops << " loop handle(s) but tiling produced "
             << tiledResults->loops.size() << " loop(s)";

    // Replace the original op with the values produced by the tiled loops.
    rewriter.replaceOp(tilingInterfaceOp, tiledResults->replacements);

    // Report back the relevant handles to the transform op.
    tiledOps.push_back(tiledResults->tiledOps.front());
    for (unsigned i = 0; i < numLoops; ++i)
      loopOps[i].push_back(tiledResults->loops[i]);
  }

  transformResults.set(transformOp->getOpResult(0), tiledOps);
  for (unsigned i = 0; i < numLoops; ++i)
    transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);

  return success();
}

/// Tiles every target op into `scf.forall` loops using this op's static tile
/// sizes, interchange and optional mapping.
DiagnosedSilenceableFailure
transform::TestTileUsingForallOp::apply(TransformRewriter &rewriter,
                                        TransformResults &transformResults,
                                        TransformState &state) {
  MLIRContext *ctx = rewriter.getContext();
  SmallVector<OpFoldResult> tileSizesOfr = getAsIndexOpFoldResult(
      ctx, extractFromIntegerArrayAttr<int64_t>(getTileSizes()));
  SmallVector<int64_t> staticInterchange =
      extractFromIntegerArrayAttr<int64_t>(getInterchange());

  if (failed(applyTileToAll(rewriter, getOperation(),
                            state.getPayloadOps(getTarget()), tileSizesOfr,
                            staticInterchange, getMapping(),
                            transformResults)))
    return DiagnosedSilenceableFailure::definiteFailure();
  return DiagnosedSilenceableFailure::success();
}

void transform::TestTileUsingForallOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // The target handle is consumed (its payload op gets replaced), all results
  // are newly produced handles, and the payload IR is rewritten.
  consumesHandle(getTargetMutable(), effects);
  auto newHandles = getOperation()->getOpResults();
  producesHandle(newHandles, effects);
  modifiesPayload(effects);
}

//===----------------------------------------------------------------------===//
// TestFuseUsingForallOp
//===----------------------------------------------------------------------===//

/// Apply the tile-and-fuse transformation implemented by `applyFn` to all
/// payload ops and store both the tiled operation as well as the created
/// tile loops.
///
/// Every payload op must implement `TilingInterface`. `applyFn` is expected
/// to produce exactly one `scf.forall` of rank `numLoops` per payload op; this
/// is only checked by an assertion. The original op and any fused producers
/// are replaced by the tiled values and erased once unused.
///
/// Result #0 of `transformOp` receives the first tiled-and-fused op of each
/// payload op. Result #1 receives the generated `scf.forall` loops, one per
/// payload op.
template <typename Range>
static LogicalResult applyTilingToAll(
    RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
    unsigned numLoops, TransformResults &transformResults,
    function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
        applyFn) {
  SmallVector<Operation *> tiledLinalgOps;
  SmallVector<Operation *> forallOps;

  for (Operation *target : payloadOps) {
    auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
    if (!tilingInterfaceOp)
      return transformOp->emitError("only TilingInterface ops are supported");

    rewriter.setInsertionPoint(target);
    FailureOr<scf::SCFTileAndFuseResult> tiledResults =
        applyFn(tilingInterfaceOp);
    if (failed(tiledResults))
      return failure();

    // Perform the replacement of tiled and fused values.
    SmallVector<Operation *> opsToReplace{target};
    llvm::append_range(opsToReplace, tiledResults->fusedProducers);
    for (Operation *toReplace : opsToReplace) {
      for (OpResult res : toReplace->getResults())
        if (auto replacement = tiledResults->replacements.lookup(res))
          rewriter.replaceAllUsesWith(res, replacement);
      if (toReplace->use_empty())
        rewriter.eraseOp(toReplace);
    }

    // Report back the relevant handles to the transform op.
    tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
    assert(tiledResults->loops.size() == 1 &&
           cast<scf::ForallOp>(tiledResults->loops[0]).getRank() == numLoops &&
           "Mismatched number of loops, tile and fuse transform should have "
           "failed");
    // Accumulate across payload ops so no loop handle is dropped.
    forallOps.push_back(tiledResults->loops[0]);
  }

  transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
  transformResults.set(transformOp->getOpResult(1), forallOps);

  return success();
}

/// Tiles each root op into a single `scf.forall` and fuses its producers,
/// using this op's static tile sizes and interchange.
DiagnosedSilenceableFailure
transform::TestFuseUsingForallOp::apply(TransformRewriter &rewriter,
                                        TransformResults &transformResults,
                                        TransformState &state) {
  SmallVector<int64_t> tileSizes =
      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
  SmallVector<int64_t> tileInterchange =
      extractFromIntegerArrayAttr<int64_t>(getInterchange());
  SmallVector<OpFoldResult> tileSizesOfr =
      getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);

  // Configure the options through their setters, like the other ops here.
  scf::SCFTilingOptions tilingOptions;
  tilingOptions.setTileSizes(tileSizesOfr)
      .setInterchange(tileInterchange)
      .setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
  scf::SCFTileAndFuseOptions tileAndFuseOptions;
  tileAndFuseOptions.setTilingOptions(tilingOptions);

  // The single scf.forall is expected to have one dimension per non-zero
  // tile size.
  LogicalResult result = applyTilingToAll(
      rewriter, getOperation(), state.getPayloadOps(getRootOp()),
      tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
      [&](TilingInterface tilingInterfaceOp)
          -> FailureOr<scf::SCFTileAndFuseResult> {
        return scf::tileConsumerAndFuseProducersUsingSCF(
            rewriter, tilingInterfaceOp, tileAndFuseOptions);
      });
  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
                        : DiagnosedSilenceableFailure::success();
}

void transform::TestFuseUsingForallOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // The root handle is consumed, every result is a newly produced handle, and
  // the payload IR is rewritten.
  consumesHandle(getRootOpMutable(), effects);
  auto newHandles = getOperation()->getOpResults();
  producesHandle(newHandles, effects);
  modifiesPayload(effects);
}

//===----------------------------------------------------------------------===//
// TestTileAndFuseOuterParallelPartialReduction
//===----------------------------------------------------------------------===//

/// Tiles the reduction loops of the root op(s) with the
/// `PartialReductionOuterParallel` strategy inside an `scf.forall`, fusing
/// producers along the way.
///
/// `reduction_dims` selects the loops to tile (defaulting to every reduction
/// loop of the first root op), and `tile_sizes[i]` is the tile size for loop
/// `reduction_dims[i]`. All other loops are left untiled.
DiagnosedSilenceableFailure
transform::TestTileAndFuseOuterParallelPartialReductionOp::apply(
    TransformRewriter &rewriter, TransformResults &transformResults,
    TransformState &state) {
  // The full tile-size vector is derived from the first root payload op.
  auto rootPayloadOps = state.getPayloadOps(getRootOp());
  if (llvm::empty(rootPayloadOps)) {
    emitOpError("expected the root handle to map to at least one payload op");
    return DiagnosedSilenceableFailure::definiteFailure();
  }
  auto target = dyn_cast<TilingInterface>(*rootPayloadOps.begin());
  if (!target) {
    emitOpError("expected root operation to implement `TilingInterface`");
    return DiagnosedSilenceableFailure::definiteFailure();
  }

  SmallVector<utils::IteratorType> iteratorTypes =
      target.getLoopIteratorTypes();

  // Default to every reduction loop when no dimensions are given.
  SmallVector<unsigned> reductionDims =
      extractFromIntegerArrayAttr<unsigned>(getReductionDims());
  if (reductionDims.empty()) {
    for (auto [index, iterator] : llvm::enumerate(iteratorTypes))
      if (iterator == utils::IteratorType::reduction)
        reductionDims.push_back(index);
  }

  if (reductionDims.empty()) {
    emitOpError(
        "no reduction dimension specified or found in the target operation");
    return DiagnosedSilenceableFailure::definiteFailure();
  }

  SmallVector<int64_t> reductionTileSizes =
      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
  if (reductionTileSizes.size() != reductionDims.size()) {
    emitOpError(
        "missing tile sizes for reduction dimensions that are to be tiled");
    return DiagnosedSilenceableFailure::definiteFailure();
  }

  // Expand to one tile size per loop: `reductionTileSizes[i]` applies to loop
  // `reductionDims[i]`, and every other loop gets tile size 0. Indexing by
  // `reductionDims` avoids reading past `reductionTileSizes` when only a
  // subset of the reduction loops is requested.
  OpFoldResult zero = rewriter.getIndexAttr(0);
  SmallVector<OpFoldResult> tileSizes(iteratorTypes.size(), zero);
  for (auto [dim, tileSize] :
       llvm::zip_equal(reductionDims, reductionTileSizes)) {
    if (dim >= tileSizes.size()) {
      emitOpError("reduction dimension ")
          << dim << " is out of bounds for an operation with "
          << tileSizes.size() << " loops";
      return DiagnosedSilenceableFailure::definiteFailure();
    }
    tileSizes[dim] = rewriter.getIndexAttr(tileSize);
  }

  scf::SCFTilingOptions tilingOptions;
  tilingOptions.setTileSizes(tileSizes)
      .setLoopType(scf::SCFTilingOptions::LoopType::ForallOp)
      .setReductionTilingStrategy(
          ReductionTilingStrategy::PartialReductionOuterParallel)
      .setReductionDims(reductionDims);
  if (auto mapping = getMapping())
    tilingOptions.setMapping(*mapping);

  LogicalResult result = applyTileAndFuseToAll(
      rewriter, getOperation(), state.getPayloadOps(getRootOp()),
      /*numLoops =*/1, tilingOptions, transformResults);

  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
                        : DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// TestTileUsingCustomLoopOp
//===----------------------------------------------------------------------===//

/// Tiles the first root payload op using custom loop-generation callbacks.
///
/// All tiled loops are linearized into a single `scf.for` from 0 to the
/// product of the per-dimension iteration counts. Its induction variable is
/// delinearized to recover the per-dimension tile offsets and sizes, and the
/// tiled results are inserted back with `tensor.insert_slice` and yielded.
/// Only unit strides, non-zero tile sizes and tensor destinations are
/// supported.
DiagnosedSilenceableFailure transform::TestTileUsingCustomLoopOp::apply(
    TransformRewriter &transformRewriter, TransformResults &transformResults,
    TransformState &state) {
  // Dereferencing `begin()` of an empty payload range is undefined behavior.
  auto rootPayloadOps = state.getPayloadOps(getRootOp());
  if (llvm::empty(rootPayloadOps)) {
    emitOpError("expected the root handle to map to at least one payload op");
    return DiagnosedSilenceableFailure::definiteFailure();
  }
  auto target = dyn_cast<TilingInterface>(*rootPayloadOps.begin());
  if (!target) {
    emitOpError("expected root operation to implement `TilingInterface`");
    return DiagnosedSilenceableFailure::definiteFailure();
  }

  OpFoldResult oneOfr = transformRewriter.getIndexAttr(1);

  // Builds the linearized scf.for and computes per-dimension tile offsets and
  // sizes inside its body.
  scf::SCFTilingOptions::GenerateLoopHeaderFn loopHeaderFn =
      [&](RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
          ArrayRef<OpFoldResult> givenTileSizes,
          ValueRange outerDestinationTensors)
      -> FailureOr<scf::SCFTilingOptions::CustomLoopHeaderInfo> {
    // Check that the strides are all 1 (to make it easier in the test).
    if (llvm::any_of(loopRanges, [](Range r) {
          return !isConstantIntValue(r.stride, 1);
        })) {
      return emitOpError("unable to handle loop ranges with strides != 1");
    }
    // Check number of tile sizes is equal to loop dimensions.
    if (loopRanges.size() != givenTileSizes.size()) {
      return emitOpError("expected number of tile sizes to be same as the "
                         "number of loops in the operation");
    }
    // For testing disallow any of the tile sizes being 0.
    if (llvm::any_of(givenTileSizes, isZeroInteger)) {
      return emitOpError("unhandled case of zero tile size");
    }
    // For testing, only handle tensor tiling.
    if (outerDestinationTensors.empty()) {
      return emitOpError("expected destination tensors");
    }

    // Compute the number of iterations for each of the loops.
    // NOTE(review): `loopRange.size` is used as the upper bound here and in
    // `minMap` below, which is only exact for zero offsets -- confirm this
    // matches the intended `Range` convention.
    AffineExpr s0, s1, s2;
    bindSymbols(rewriter.getContext(), s0, s1, s2);
    AffineExpr numItersExpr = (s1 - s0).ceilDiv(s2); // (ub - lb) / tileSize

    SmallVector<OpFoldResult> allNumIters;
    allNumIters.reserve(loopRanges.size());
    for (auto [loopRange, tileSize] :
         llvm::zip_equal(loopRanges, givenTileSizes)) {
      OpFoldResult numIters = affine::makeComposedFoldedAffineApply(
          rewriter, loc, numItersExpr,
          {loopRange.offset, loopRange.size, tileSize});
      allNumIters.push_back(numIters);
    }
    if (allNumIters.empty()) {
      return emitOpError("invalid empty tile sizes and loop ranges");
    }

    // The linearized trip count is the product of all per-dimension counts.
    AffineExpr mulExpr = s0 * s1;
    OpFoldResult cumulative = oneOfr;
    for (auto numIters : allNumIters) {
      cumulative = affine::makeComposedFoldedAffineApply(
          rewriter, loc, mulExpr, {cumulative, numIters});
    }

    Value zeroVal = arith::ConstantIndexOp::create(rewriter, loc, 0);
    Value oneVal = arith::ConstantIndexOp::create(rewriter, loc, 1);
    Value ub = getValueOrCreateConstantIndexOp(rewriter, loc, cumulative);

    SmallVector<OpFoldResult> offsets;
    SmallVector<OpFoldResult> sizes;
    SmallVector<Value> innerDestinationTensors;
    offsets.reserve(loopRanges.size());
    sizes.reserve(loopRanges.size());

    AffineExpr d0;
    bindDims(rewriter.getContext(), d0);
    AffineExpr offsetExpr = s0 + d0 * s1; // lb + iv * tileSize
    AffineMap minMap =
        AffineMap::get(1, 2, {s0 - d0, s1},
                       rewriter.getContext()); // min(ub - offset, tileSize)
    // The body builder runs during `create`, so `offsets`, `sizes` and
    // `innerDestinationTensors` are populated before they are returned below.
    auto forOp = scf::ForOp::create(
        rewriter, loc, zeroVal, ub, oneVal, outerDestinationTensors,
        [&](OpBuilder &b, Location bodyLoc, Value linearizedIv,
            ValueRange destinations) {
          auto delinearizeOp = affine::AffineDelinearizeIndexOp::create(
              b, bodyLoc, linearizedIv, allNumIters);
          for (auto [normalizedIv, range, tileSize] : llvm::zip_equal(
                   delinearizeOp.getResults(), loopRanges, givenTileSizes)) {

            OpFoldResult normalizedIvOfr = getAsOpFoldResult(normalizedIv);
            OpFoldResult offset = affine::makeComposedFoldedAffineApply(
                b, bodyLoc, offsetExpr,
                {normalizedIvOfr, range.offset, tileSize});
            offsets.push_back(offset);

            OpFoldResult size = affine::makeComposedFoldedAffineMin(
                b, bodyLoc, minMap, {offset, range.size, tileSize});
            sizes.push_back(size);
          }
          innerDestinationTensors = llvm::to_vector(destinations);
        });
    rewriter.setInsertionPointToEnd(forOp.getBody());
    return scf::SCFTilingOptions::CustomLoopHeaderInfo{
        {cast<LoopLikeOpInterface>(forOp.getOperation())},
        offsets,
        sizes,
        innerDestinationTensors};
  };

  // Inserts each tiled result into its destination and yields the results.
  scf::SCFTilingOptions::GenerateLoopTerminatorFn terminatorFn =
      [&](RewriterBase &rewriter, Location loc,
          ArrayRef<LoopLikeOpInterface> loops, ValueRange tiledResults,
          ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
          ArrayRef<SmallVector<OpFoldResult>> resultSizes,
          ValueRange destinationTensors) -> LogicalResult {
    SmallVector<Value> yieldValues;
    yieldValues.reserve(destinationTensors.size());
    for (auto [tiledResult, offsets, sizes, destination] : llvm::zip_equal(
             tiledResults, resultOffsets, resultSizes, destinationTensors)) {
      SmallVector<OpFoldResult> strides(offsets.size(), oneOfr);
      Value insertedVal = tensor::InsertSliceOp::create(
          rewriter, loc, tiledResult, destination, offsets, sizes, strides);
      yieldValues.push_back(insertedVal);
    }
    scf::YieldOp::create(rewriter, loc, yieldValues);
    return success();
  };

  scf::SCFTilingOptions tilingOptions;
  SmallVector<int64_t> staticTileSizes =
      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
  SmallVector<OpFoldResult> tileSizes =
      getAsIndexOpFoldResult(transformRewriter.getContext(), staticTileSizes);
  tilingOptions.setTileSizes(tileSizes)
      .setLoopType(scf::SCFTilingOptions::LoopType::CustomOp)
      .setCustomLoopGenerationFns(loopHeaderFn, terminatorFn);

  OpBuilder::InsertionGuard g(transformRewriter);
  transformRewriter.setInsertionPoint(target);
  FailureOr<scf::SCFTilingResult> tiledResults =
      scf::tileUsingSCF(transformRewriter, target, tilingOptions);
  if (failed(tiledResults)) {
    return DiagnosedSilenceableFailure::definiteFailure();
  }
  transformRewriter.replaceOp(target, tiledResults->replacements);
  transformResults.set(getOperation()->getResult(0), tiledResults->tiledOps);
  transformResults.set(getOperation()->getResult(1), tiledResults->loops);

  return DiagnosedSilenceableFailure::success();
}

#define GET_OP_CLASSES
#include "TestTilingInterfaceTransformOps.cpp.inc"

namespace {
/// Transform-dialect extension that registers the test TilingInterface
/// transform ops defined in this file and declares the dialects those
/// transforms depend on.
class TestTilingInterfaceDialectExtension
    : public transform::TransformDialectExtension<
          TestTilingInterfaceDialectExtension> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestTilingInterfaceDialectExtension)

  using Base::Base;

  /// Declares dependent dialects and registers the generated transform ops.
  void init() {
    // The transforms in this file create ops from these dialects (e.g.
    // affine.apply/affine.delinearize_index, scf.for, tensor.insert_slice).
    declareDependentDialect<affine::AffineDialect>();
    declareDependentDialect<index::IndexDialect>();
    declareDependentDialect<scf::SCFDialect>();
    declareDependentDialect<tensor::TensorDialect>();

    // The op list is expanded from the ODS-generated .cpp.inc file.
    registerTransformOps<
#define GET_OP_LIST
#include "TestTilingInterfaceTransformOps.cpp.inc"
        >();
  }
};
} // namespace

namespace test {
/// Adds the TilingInterface test transform-dialect extension to `registry`.
void registerTestTilingInterfaceTransformDialectExtension(
    DialectRegistry &registry) {
  registry.addExtensions<TestTilingInterfaceDialectExtension>();
}
} // namespace test
